-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][XeGPU] Allow create mem desc from 2d memref #167767
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][XeGPU] Allow create mem desc from 2d memref #167767
Conversation
|
@llvm/pr-subscribers-mlir-gpu Author: Jianhui Li (Jianhui-Li) ChangesThis PR relax the create_mem_desc's restriction on source memref, allowing it to be a nd memref. Full diff: https://github.com/llvm/llvm-project/pull/167767.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 689ebd0d1179a..59dc0eade0744 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1308,7 +1308,7 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
Results:
- `mem_desc` : the memory descriptor.
}];
- let arguments = (ins StaticShared1DMemRefOf<[I8]>:$source);
+ let arguments = (ins AnyTypeOf<[StaticShared1DMemRefOf<[I8]>, ConfinedType<MemRefRankOf<[XeGPU_ScalarType], [2]>, [HasStaticShapePred, isSharedPred]>]>:$source);
let results = (outs XeGPU_MemDesc:$mem_desc);
let assemblyFormat = "$source prop-dict attr-dict `` `:` type($source) `->` qualified(type($mem_desc))";
}
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 705298f497d20..c01078ca6938e 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -590,16 +590,12 @@ class CreateMemDescOpPattern final
matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto resTy = op.getMemDesc();
+ Value baseAddr = memref::ExtractAlignedPointerAsIndexOp::create(
+ rewriter, op.getLoc(), op.getSource());
+ auto baseAddr32 = arith::IndexCastUIOp::create(
+ rewriter, op.getLoc(), rewriter.getI32Type(), baseAddr);
- // Create the result MemRefType with the same shape, element type, and
- // memory space
- auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
-
- Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
- auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
- op.getSource(), zero, ValueRange());
- rewriter.replaceOp(op, viewOp);
+ rewriter.replaceOp(op, baseAddr32);
return success();
}
};
@@ -619,7 +615,7 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
- Value basePtrStruct = adaptor.getMemDesc();
+ Value baseAddr32 = adaptor.getMemDesc();
Value mdescVal = op.getMemDesc();
// Load result or Store value Type can be vector or scalar.
Value data;
@@ -647,21 +643,14 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
- Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
- rewriter, loc, basePtrStruct);
-
- // Convert base pointer (ptr) to i32
- Value basePtrI32 = arith::IndexCastUIOp::create(
- rewriter, loc, rewriter.getI32Type(), basePtrLLVM);
-
Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
linearOffset = arith::IndexCastUIOp::create(
rewriter, loc, rewriter.getI32Type(), linearOffset);
- basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset,
- elemByteSize);
+ Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
+ linearOffset, elemByteSize);
// convert base pointer (i32) to LLVM pointer type
- basePtrLLVM =
+ Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
if (op.getSubgroupBlockIoAttr()) {
@@ -1005,11 +994,9 @@ struct ConvertXeGPUToXeVMPass
auto i32Type = IntegerType::get(&getContext(), 32);
return VectorType::get(8, i32Type);
});
- // Convert MemDescType into flattened MemRefType for SLM
+ // Convert MemDescType into i32 for SLM
typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
- Type elemTy = type.getElementType();
- int numElems = type.getNumElements();
- return MemRefType::get(numElems, elemTy, AffineMap(), 3);
+ return IntegerType::get(&getContext(), 32);
});
typeConverter.addConversion([&](MemRefType type) -> Type {
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index d4cb493271d0d..5b3776ba0039c 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -4,8 +4,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
// e.g. for mem_desc<32x32xf16, @strides=[1, 16]>
// its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1])
- //CHECK-LABEL: load_store_matrix_1
- gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 {
+ //CHECK-LABEL: load_store_matrix_plain
+ gpu.func @load_store_matrix_plain(%arg0: memref<4096xi8, 3>) -> f32 {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
//CHECK: %[[TID:.*]] = gpu.thread_id x
@@ -26,12 +26,40 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
gpu.return %1: f32
}
+ //CHECK-LABEL: load_store_matrix_plain_2d_input
+ gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<8192xi8, 3>) -> f32 {
+ %c0 = arith.constant 0 : index
+ %view = memref.view %arg0[%c0][]: memref<8192xi8, 3> to memref<64x32xf32, 3>
+
+ %subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
+
+ %0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>
+
+ //CHECK: %[[TID:.*]] = gpu.thread_id x
+ //CHECK: %[[C1:.*]] = arith.constant 1 : index
+ //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
+ //CHECK: %[[C4:.*]] = arith.constant 4 : i32
+ //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
+ //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
+
+ %tid_x = gpu.thread_id x
+
+ %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
+
+ //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
+
+ xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
+
+ gpu.return %1: f32
+ }
+
+
// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
// its memory layout tuple is ([2,4,16,16],[256,512,1,16])
- //CHECK-LABEL: load_store_matrix_2
- gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 {
+ //CHECK-LABEL: load_store_matrix_blocked_strided
+ gpu.func @load_store_matrix_blocked_strided(%arg0: memref<4096xi8, 3>) -> f16 {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
- //CHECK: %[[c0:.*]] = arith.constant 0 : index
+
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
//CHECK: %[[c13:.*]] = arith.constant 13 : index
//CHECK: %[[c16:.*]] = arith.constant 16 : index
@@ -39,7 +67,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
-
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c256:.*]] = arith.constant 256 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -68,10 +96,11 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
// e.g. for mem_desc<32x64xf16, @block=[16, 16]>
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
- //CHECK-LABEL: load_store_matrix_3
- gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 {
- //CHECK: %[[c0:.*]] = arith.constant 0 : index
- //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
+ //CHECK-LABEL: load_store_matrix_blocked_nostride
+ gpu.func @load_store_matrix_blocked_nostride(%arg0: memref<4096xi8, 3>) -> f16 {
+
+ //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
+ //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
@@ -79,13 +108,12 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
%tid_x = gpu.thread_id x
%c19 = arith.constant 19: index
- //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
- //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
//CHECK: %[[c16:.*]] = arith.constant 16 : index
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c1024:.*]] = arith.constant 1024 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -97,7 +125,6 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[c1:.*]] = arith.constant 1 : index
//CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
//CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
-
//CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
%1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
@@ -110,19 +137,17 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]>
// its memory layout tuple is ([2,4,16,16],[256,512,1,16])
- //CHECK-LABEL: load_store_matrix_4
- gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
+ //CHECK-LABEL: load_store_matrix_blocked_strided_return_vector
+ gpu.func @load_store_matrix_blocked_strided_return_vector(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
- //CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
-
//CHECK: %[[c16:.*]] = arith.constant 16 : index
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
-
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c256:.*]] = arith.constant 256 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -150,25 +175,24 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
// e.g. for mem_desc<32x64xf16, @block=[16, 16]>
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
- //CHECK-LABEL: load_store_matrix_5
- gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
- //CHECK: %[[c0:.*]] = arith.constant 0 : index
- //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
-
- %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
-
+ //CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
+ gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
+
+ //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
+ //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
+ %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+
+
//CHECK: %[[c16:.*]] = arith.constant 16 : index
//CHECK: %[[c48:.*]] = arith.constant 48 : index
-
%c16 = arith.constant 16 : index
%c48 = arith.constant 48 : index
- //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
- //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
//CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
//CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
//CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
//CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c1024:.*]] = arith.constant 1024 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 9b3829664108d..1e9738f44bb66 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -834,6 +834,27 @@ gpu.func @create_mem_desc_with_stride() {
gpu.return
}
+
+// CHECK-LABEL: gpu.func @create_mem_desc_from_2d_memref({{.*}}) {
+gpu.func @create_mem_desc_from_2d_memref() {
+ //CHECK: [[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<16x64xf16, 3>
+ //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[alloc]] : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16>
+ %m = memref.alloca() {alignment = 1024} : memref<16x64xf16, 3>
+ %mem_desc = xegpu.create_mem_desc %m : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @create_mem_desc_with_stride_from_2d_memref({{.*}}) {
+gpu.func @create_mem_desc_with_stride_from_2d_memref() {
+ //CHECK: %[[ALLOC:.+]] = memref.alloca() {alignment = 1024 : i64} : memref<32x64xf16, 3>
+ //CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][16, 0] [16, 64] [1, 1] : memref<32x64xf16, 3> to memref<16x64xf16, strided<[64, 1], offset: 1024>, 3>
+ //CHECK: %{{.+}} = xegpu.create_mem_desc %[[SUBVIEW]] : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ %m = memref.alloca() {alignment = 1024} : memref<32x64xf16, 3>
+ %m_sub = memref.subview %m[16, 0][16, 64][1,1] : memref<32x64xf16, 3> to memref<16x64xf16, strided<[64, 1], offset: 1024>, 3>
+ %mem_desc = xegpu.create_mem_desc %m_sub : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ gpu.return
+}
+
// CHECK: gpu.func @load_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
gpu.func @load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
// CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
|
|
@llvm/pr-subscribers-mlir Author: Jianhui Li (Jianhui-Li) ChangesThis PR relax the create_mem_desc's restriction on source memref, allowing it to be a nd memref. Full diff: https://github.com/llvm/llvm-project/pull/167767.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 689ebd0d1179a..59dc0eade0744 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1308,7 +1308,7 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
Results:
- `mem_desc` : the memory descriptor.
}];
- let arguments = (ins StaticShared1DMemRefOf<[I8]>:$source);
+ let arguments = (ins AnyTypeOf<[StaticShared1DMemRefOf<[I8]>, ConfinedType<MemRefRankOf<[XeGPU_ScalarType], [2]>, [HasStaticShapePred, isSharedPred]>]>:$source);
let results = (outs XeGPU_MemDesc:$mem_desc);
let assemblyFormat = "$source prop-dict attr-dict `` `:` type($source) `->` qualified(type($mem_desc))";
}
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 705298f497d20..c01078ca6938e 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -590,16 +590,12 @@ class CreateMemDescOpPattern final
matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto resTy = op.getMemDesc();
+ Value baseAddr = memref::ExtractAlignedPointerAsIndexOp::create(
+ rewriter, op.getLoc(), op.getSource());
+ auto baseAddr32 = arith::IndexCastUIOp::create(
+ rewriter, op.getLoc(), rewriter.getI32Type(), baseAddr);
- // Create the result MemRefType with the same shape, element type, and
- // memory space
- auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
-
- Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
- auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
- op.getSource(), zero, ValueRange());
- rewriter.replaceOp(op, viewOp);
+ rewriter.replaceOp(op, baseAddr32);
return success();
}
};
@@ -619,7 +615,7 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
- Value basePtrStruct = adaptor.getMemDesc();
+ Value baseAddr32 = adaptor.getMemDesc();
Value mdescVal = op.getMemDesc();
// Load result or Store value Type can be vector or scalar.
Value data;
@@ -647,21 +643,14 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
- Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
- rewriter, loc, basePtrStruct);
-
- // Convert base pointer (ptr) to i32
- Value basePtrI32 = arith::IndexCastUIOp::create(
- rewriter, loc, rewriter.getI32Type(), basePtrLLVM);
-
Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
linearOffset = arith::IndexCastUIOp::create(
rewriter, loc, rewriter.getI32Type(), linearOffset);
- basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset,
- elemByteSize);
+ Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
+ linearOffset, elemByteSize);
// convert base pointer (i32) to LLVM pointer type
- basePtrLLVM =
+ Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
if (op.getSubgroupBlockIoAttr()) {
@@ -1005,11 +994,9 @@ struct ConvertXeGPUToXeVMPass
auto i32Type = IntegerType::get(&getContext(), 32);
return VectorType::get(8, i32Type);
});
- // Convert MemDescType into flattened MemRefType for SLM
+ // Convert MemDescType into i32 for SLM
typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
- Type elemTy = type.getElementType();
- int numElems = type.getNumElements();
- return MemRefType::get(numElems, elemTy, AffineMap(), 3);
+ return IntegerType::get(&getContext(), 32);
});
typeConverter.addConversion([&](MemRefType type) -> Type {
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index d4cb493271d0d..5b3776ba0039c 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -4,8 +4,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
// e.g. for mem_desc<32x32xf16, @strides=[1, 16]>
// its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1])
- //CHECK-LABEL: load_store_matrix_1
- gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 {
+ //CHECK-LABEL: load_store_matrix_plain
+ gpu.func @load_store_matrix_plain(%arg0: memref<4096xi8, 3>) -> f32 {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
//CHECK: %[[TID:.*]] = gpu.thread_id x
@@ -26,12 +26,40 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
gpu.return %1: f32
}
+ //CHECK-LABEL: load_store_matrix_plain_2d_input
+ gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<8192xi8, 3>) -> f32 {
+ %c0 = arith.constant 0 : index
+ %view = memref.view %arg0[%c0][]: memref<8192xi8, 3> to memref<64x32xf32, 3>
+
+ %subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
+
+ %0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>
+
+ //CHECK: %[[TID:.*]] = gpu.thread_id x
+ //CHECK: %[[C1:.*]] = arith.constant 1 : index
+ //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
+ //CHECK: %[[C4:.*]] = arith.constant 4 : i32
+ //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
+ //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
+
+ %tid_x = gpu.thread_id x
+
+ %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
+
+ //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
+
+ xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
+
+ gpu.return %1: f32
+ }
+
+
// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
// its memory layout tuple is ([2,4,16,16],[256,512,1,16])
- //CHECK-LABEL: load_store_matrix_2
- gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 {
+ //CHECK-LABEL: load_store_matrix_blocked_strided
+ gpu.func @load_store_matrix_blocked_strided(%arg0: memref<4096xi8, 3>) -> f16 {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
- //CHECK: %[[c0:.*]] = arith.constant 0 : index
+
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
//CHECK: %[[c13:.*]] = arith.constant 13 : index
//CHECK: %[[c16:.*]] = arith.constant 16 : index
@@ -39,7 +67,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
-
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c256:.*]] = arith.constant 256 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -68,10 +96,11 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
// e.g. for mem_desc<32x64xf16, @block=[16, 16]>
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
- //CHECK-LABEL: load_store_matrix_3
- gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 {
- //CHECK: %[[c0:.*]] = arith.constant 0 : index
- //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
+ //CHECK-LABEL: load_store_matrix_blocked_nostride
+ gpu.func @load_store_matrix_blocked_nostride(%arg0: memref<4096xi8, 3>) -> f16 {
+
+ //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
+ //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
@@ -79,13 +108,12 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
%tid_x = gpu.thread_id x
%c19 = arith.constant 19: index
- //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
- //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
//CHECK: %[[c16:.*]] = arith.constant 16 : index
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c1024:.*]] = arith.constant 1024 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -97,7 +125,6 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[c1:.*]] = arith.constant 1 : index
//CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
//CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
-
//CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
%1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
@@ -110,19 +137,17 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]>
// its memory layout tuple is ([2,4,16,16],[256,512,1,16])
- //CHECK-LABEL: load_store_matrix_4
- gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
+ //CHECK-LABEL: load_store_matrix_blocked_strided_return_vector
+ gpu.func @load_store_matrix_blocked_strided_return_vector(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
- //CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
-
//CHECK: %[[c16:.*]] = arith.constant 16 : index
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
-
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c256:.*]] = arith.constant 256 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -150,25 +175,24 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
// e.g. for mem_desc<32x64xf16, @block=[16, 16]>
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
- //CHECK-LABEL: load_store_matrix_5
- gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
- //CHECK: %[[c0:.*]] = arith.constant 0 : index
- //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
-
- %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
-
+ //CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
+ gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
+
+ //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
+ //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
+ %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+
+
//CHECK: %[[c16:.*]] = arith.constant 16 : index
//CHECK: %[[c48:.*]] = arith.constant 48 : index
-
%c16 = arith.constant 16 : index
%c48 = arith.constant 48 : index
- //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
- //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
//CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
//CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
//CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
//CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
//CHECK: %[[c1024:.*]] = arith.constant 1024 : index
//CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 9b3829664108d..1e9738f44bb66 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -834,6 +834,27 @@ gpu.func @create_mem_desc_with_stride() {
gpu.return
}
+
+// CHECK-LABEL: gpu.func @create_mem_desc_from_2d_memref({{.*}}) {
+gpu.func @create_mem_desc_from_2d_memref() {
+ //CHECK: [[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<16x64xf16, 3>
+ //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[alloc]] : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16>
+ %m = memref.alloca() {alignment = 1024} : memref<16x64xf16, 3>
+ %mem_desc = xegpu.create_mem_desc %m : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @create_mem_desc_with_stride_from_2d_memref({{.*}}) {
+gpu.func @create_mem_desc_with_stride_from_2d_memref() {
+ //CHECK: %[[ALLOC:.+]] = memref.alloca() {alignment = 1024 : i64} : memref<32x64xf16, 3>
+ //CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][16, 0] [16, 64] [1, 1] : memref<32x64xf16, 3> to memref<16x64xf16, strided<[64, 1], offset: 1024>, 3>
+ //CHECK: %{{.+}} = xegpu.create_mem_desc %[[SUBVIEW]] : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ %m = memref.alloca() {alignment = 1024} : memref<32x64xf16, 3>
+ %m_sub = memref.subview %m[16, 0][16, 64][1,1] : memref<32x64xf16, 3> to memref<16x64xf16, strided<[64, 1], offset: 1024>, 3>
+ %mem_desc = xegpu.create_mem_desc %m_sub : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ gpu.return
+}
+
// CHECK: gpu.func @load_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
gpu.func @load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
// CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
|
| gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> { | ||
|
|
||
| //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index | ||
| //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: wrong type in the name, maybe a generic basePtr would be better?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed.
charithaintc
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
🐧 Linux x64 Test Results
|
silee2
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/169/builds/17223 Here is the relevant piece of the build log for the reference |
This PR relax the create_mem_desc's restriction on source memref, allowing it to be a 2d memref.